/*
 * Copyright 2020 eBlocker Open Source UG (haftungsbeschraenkt)
 *
 * Licensed under the EUPL, Version 1.2 or - as soon they will be
 * approved by the European Commission - subsequent versions of the EUPL
 * (the "License"); You may not use this work except in compliance with
 * the License. You may obtain a copy of the License at:
 *
 *   https://joinup.ec.europa.eu/page/eupl-text-11-12
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" basis,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied. See the License for the specific language governing
 * permissions and limitations under the License.
 */
package org.eblocker.server.common.malware;

import org.eblocker.server.common.data.DataSource;
import org.eblocker.server.common.data.systemstatus.SubSystem;
import org.eblocker.server.common.network.unix.IpSetConfig;
import org.eblocker.server.common.network.unix.IpSets;
import org.eblocker.server.common.startup.SubSystemInit;
import org.eblocker.server.common.startup.SubSystemService;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.inject.Inject;
import com.google.inject.Singleton;
import com.google.inject.name.Named;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.stream.Collectors;

/**
 * Service that filters URLs and IP/port combinations listed in two malware definition files.
 *
 * <p>URL entries are read through the injected {@link ObjectMapper} into a sorted map and are
 * matched by prefix against normalized request URLs. IP entries are read from a second file and
 * pushed into an IP set, provided IP sets are supported by the operating system.
 *
 * <p>The enabled flag is read from and written to the {@link DataSource}. While the filter is
 * disabled, both the URL map and the IP set are kept empty.
 */
@Singleton
@SubSystemService(value = SubSystem.NETWORK_STATE_MACHINE, initPriority = 50)
public class MalwareFilterService {
    private static final Logger log = LoggerFactory.getLogger(MalwareFilterService.class);

    private final Path malwareUrlsFilePath;
    private final Path malwareIpsFilePath;
    private final IpSetConfig ipSetConfig;
    private final DataSource dataSource;
    private final IpSets ipSets;
    private final ObjectMapper objectMapper;

    private boolean enabled;
    /** Modification time (epoch millis) of the newest definition file that was last loaded. */
    private long lastUpdate;
    /** Malware names keyed by URL prefix; replaced as a whole whenever the definitions are reloaded. */
    private TreeMap<String, String[]> malwareByUrl = new TreeMap<>();

    /**
     * Creates the service. No files are read here; loading happens in {@link #init()}.
     *
     * @param malwareUrlsFilePath path of the file containing the malware URL entries
     * @param malwareIpsFilePath  path of the file mapping IP addresses to lists of ports
     * @param ipSetConfig         configuration of the IP set that receives the malware IPs
     * @param dataSource          store holding the persisted enabled flag
     * @param ipSets              access to the operating system's IP sets
     * @param objectMapper        mapper used to parse both definition files
     */
    @Inject
    public MalwareFilterService(@Named("malware.filter.urls.file.path") String malwareUrlsFilePath,
                                @Named("malware.filter.ips.file.path") String malwareIpsFilePath,
                                @Named("malware.filter.ipset") IpSetConfig ipSetConfig,
                                DataSource dataSource,
                                IpSets ipSets,
                                ObjectMapper objectMapper) {
        this.malwareUrlsFilePath = Paths.get(malwareUrlsFilePath);
        this.malwareIpsFilePath = Paths.get(malwareIpsFilePath);
        this.ipSetConfig = ipSetConfig;
        this.dataSource = dataSource;
        this.ipSets = ipSets;
        this.objectMapper = objectMapper;
    }

    /**
     * Restores the enabled flag from the data source, creates the IP set if IP sets are
     * supported, and performs an initial load of the definition files.
     *
     * @throws IOException declared by the original signature; propagated from IP set creation
     */
    @SubSystemInit
    public void init() throws IOException {
        enabled = dataSource.isMalwareUrlFilterEnabled();

        if (ipSets.isSupportedByOperatingSystem()) {
            ipSets.createIpSet(ipSetConfig);
        }
        checkUpdate();
    }

    /**
     * Checks whether the given URL is covered by a malware URL entry.
     *
     * @param url the URL to check; it is normalized before the lookup
     * @return {@code true} if the filter is enabled and a stored entry is a prefix of the
     *         normalized URL, {@code false} otherwise
     */
    public boolean isBlocked(String url) {
        // Short-circuits so that the URL is not even normalized while the filter is disabled.
        return enabled && findMatchingEntry(url) != null;
    }

    /**
     * Returns the malware names recorded for the entry matching the given URL.
     *
     * @param url the URL to look up; it is normalized before the lookup
     * @return a fixed-size list view of the matching entry's malware names, or an empty list if
     *         no entry matches (which includes the case of an empty map while disabled)
     */
    public List<String> getMalwareByUrl(String url) {
        Map.Entry<String, String[]> match = findMatchingEntry(url);
        if (match == null) {
            return Collections.emptyList();
        }
        return Arrays.asList(match.getValue());
    }

    /**
     * Returns whether the filter is currently enabled.
     *
     * @return the current enabled flag
     */
    public boolean isEnabled() {
        return enabled;
    }

    /**
     * Enables or disables the filter. On a change the flag is persisted and both the URL map and
     * the IP set are rebuilt; nothing happens if the flag already has the requested value.
     *
     * <p>An {@link IOException} during the rebuild is logged and not propagated; the in-memory
     * flag has already been changed at that point and is not rolled back.
     *
     * @param enabled the requested state
     * @return the requested state, also when the rebuild failed
     */
    public boolean setEnabled(boolean enabled) {
        if (this.enabled != enabled) {
            try {
                this.enabled = enabled;
                dataSource.setMalwareUrlFilterEnabled(enabled);
                updateUrls();
                updateFirewall();
            } catch (IOException e) {
                // The trailing throwable is logged with its stack trace by SLF4J.
                log.error("failed to set malware url filter enabled to {}", enabled, e);
            }
        }
        return enabled;
    }

    /**
     * Reloads the URL and IP definitions if either file has been modified since the last
     * successful load. I/O failures are logged as warnings and leave {@code lastUpdate}
     * unchanged, so the next call attempts the reload again.
     */
    public void checkUpdate() {
        try {
            long urlsModified = Files.getLastModifiedTime(malwareUrlsFilePath).toMillis();
            long ipsModified = Files.getLastModifiedTime(malwareIpsFilePath).toMillis();
            long fileModificationTime = Math.max(urlsModified, ipsModified);
            if (lastUpdate < fileModificationTime) {
                log.debug("file change detected, updating filter");
                updateUrls();
                updateFirewall();
                // Only recorded after both updates succeeded.
                lastUpdate = fileModificationTime;
            } else {
                log.debug("no modifications");
            }
        } catch (IOException e) {
            log.warn("failed to update malware filter", e);
        }
    }

    /**
     * Returns the modification time of the definitions that were last loaded.
     *
     * @return epoch milliseconds, or {@code 0} if nothing has been loaded yet
     */
    public long getLastUpdate() {
        return lastUpdate;
    }

    /**
     * Finds the stored entry whose key is a prefix of the normalized URL.
     *
     * <p>Only the floor entry (greatest key less than or equal to the normalized URL) is
     * examined, exactly as in the original lookup.
     *
     * @param url the URL to look up
     * @return the matching entry, or {@code null} if the map has no floor entry for the URL or
     *         the floor key is not a prefix of it
     */
    private Map.Entry<String, String[]> findMatchingEntry(String url) {
        String normalizedUrl = MalwareUtils.normalize(url);
        // floorEntry yields null when every key sorts after the URL or the map is empty.
        Map.Entry<String, String[]> floorEntry = malwareByUrl.floorEntry(normalizedUrl);
        if (floorEntry == null || !normalizedUrl.startsWith(floorEntry.getKey())) {
            return null;
        }
        return floorEntry;
    }

    /**
     * Rebuilds the URL map: empty while disabled, otherwise from the URL definition file.
     *
     * @throws IOException if the URL definition file cannot be read or parsed
     */
    private void updateUrls() throws IOException {
        if (!enabled) {
            activateEntries(Collections.emptyList());
            return;
        }

        List<MalwareEntry> entries = objectMapper.readValue(malwareUrlsFilePath.toFile(), new TypeReference<List<MalwareEntry>>() {});
        entries = optimizeEntries(entries);
        activateEntries(entries);
        log.info("filtering {} malware urls", malwareByUrl.size());
    }

    /**
     * Makes entries hosting the same set of malware names share a single array instance.
     *
     * @param entries the entries as parsed from the definition file
     * @return new entries with the same URLs, in the same order, referencing shared arrays
     */
    private List<MalwareEntry> optimizeEntries(List<MalwareEntry> entries) {
        List<MalwareEntry> optimizedEntries = new ArrayList<>(entries.size());
        // Keyed by the set of names, so arrays differing only in order or duplicates are merged
        // into the first array seen for that set.
        Map<Set<String>, String[]> sharedArrays = new HashMap<>();
        for (MalwareEntry e : entries) {
            Set<String> key = new HashSet<>(Arrays.asList(e.getHostedMalware()));
            String[] shared = sharedArrays.computeIfAbsent(key, k -> e.getHostedMalware());
            optimizedEntries.add(new MalwareEntry(e.getUrl(), shared));
        }
        return optimizedEntries;
    }

    /**
     * Builds a fresh URL map from the given entries and swaps it in for the current one.
     *
     * @param entries the entries to activate; later entries overwrite earlier ones with the same URL
     */
    private void activateEntries(List<MalwareEntry> entries) {
        TreeMap<String, String[]> newMalwareByUrl = new TreeMap<>();
        entries.forEach(e -> newMalwareByUrl.put(e.getUrl(), e.getHostedMalware()));
        this.malwareByUrl = newMalwareByUrl;
    }

    /**
     * Updates the IP set: emptied while disabled, otherwise filled with one
     * {@code <ip>,tcp:<port>} element per IP/port pair of the IP definition file. Does nothing
     * if IP sets are not supported by the operating system.
     *
     * @throws IOException if the IP definition file cannot be read or parsed, or as propagated
     *                     from the IP set update
     */
    private void updateFirewall() throws IOException {
        if (!ipSets.isSupportedByOperatingSystem()) {
            return;
        }
        if (!enabled) {
            ipSets.updateIpSet(ipSetConfig, Collections.emptySet());
            return;
        }

        Map<String, List<Integer>> portsByIp = objectMapper.readValue(malwareIpsFilePath.toFile(), new TypeReference<Map<String, List<Integer>>>() {
        });
        Set<String> entries = portsByIp.entrySet().stream()
            .flatMap(e -> e.getValue().stream().map(p -> e.getKey() + ",tcp:" + p))
            .collect(Collectors.toSet());
        ipSets.updateIpSet(ipSetConfig, entries);
    }

}
